/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file matmul_macro_v220_basic_impl.h
 * \brief
 */
#ifndef IMPL_MTAMUL_MATMUL_MACRO_V220_BASIC_IMPL_H
#define IMPL_MTAMUL_MATMUL_MACRO_V220_BASIC_IMPL_H

#include "kernel_operator.h"
#include "lib/matmul/tiling.h"

#define HW_N0 16
#define HW_M0 16
#define ALIGN_NUM 16
#define L0AUF_SIZE 65536
#define L0BUF_SIZE 65536
#define BIAS_BUF_SIZE 1024
#define L0A_PING_D 0
#define L0A_PONG_D (L0AUF_SIZE / 2)
#define L0B_PING_D 0
#define L0B_PONG_D (L0BUF_SIZE / 2)
#define BIAS_PING_D 0
#define BIAS_PONG_D (BIAS_BUF_SIZE / 2)


namespace matmul {
using namespace AscendC;

// Ceiling division: smallest q such that q * num2 >= num1. Requires num2 > 0.
__aicore__ inline uint16_t CeilDivNum(uint16_t num1, uint16_t num2)
{
    ASSERT(num2 > 0);
    uint16_t quotient = num1 / num2;
    if (num1 % num2 != 0) {
        quotient += 1;
    }
    return quotient;
}

// Round num1 up to the nearest multiple of num2. Requires num2 > 0.
__aicore__ inline uint16_t CeilAlignNum(uint16_t num1, uint16_t num2)
{
    ASSERT(num2 > 0);
    const uint16_t blocks = CeilDivNum(num1, num2);
    return blocks * num2;
}

// Number of K-direction elements per fractal line (k0) for the given
// C/A/B element-type combination.
template <typename C_T, typename A_T, typename B_T> constexpr inline __aicore__ uint16_t GetK0Value()
{
    if constexpr (IsSameType<C_T, float>::value) {
        if constexpr (sizeof(A_T) == sizeof(half)) {
            return 16; // half-sized inputs accumulated into float
        } else if constexpr (IsSameType<A_T, float>::value) {
            return 8; // float inputs
        } else {
            return 32;
        }
    } else {
        return 32; // all remaining dtype combinations
    }
}


// ===========mad template=================/
// Template parameters: C matrix type, A matrix type, B matrix type, bias element
// type, bias enable flag, matmul config, L0C unit-flag enable, L0A/B HSET enable.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, bool isBias, const auto &MM_CFG,
    uint16_t UNIFLAG_EN = 0, uint16_t L0AB_USING_HSET = 0>
class MacroMatmulBasic {
public:
    inline __aicore__ MacroMatmulBasic() {}
    // Ping/pong base offsets for the L0A/L0B/bias buffers.
    // Under ASCENDC_CPU_DEBUG, Init() rebases these onto heap allocations.
    uint64_t L0A_PING = L0A_PING_D;
    uint64_t L0A_PONG = L0A_PONG_D;
    uint64_t L0B_PING = L0B_PING_D;
    uint64_t L0B_PONG = L0B_PONG_D;
    uint64_t BIAS_PING = BIAS_PING_D;
    uint64_t BIAS_PONG = BIAS_PONG_D;
    // args
    uint16_t sAL1K_;    // K extent of the A matrix tile held in L1
    uint16_t sBL1K_;    // K extent of the B matrix tile held in L1
    uint16_t sMad0K_;   // K extent of a single MAD step
    uint16_t sL0cInit_; // 0: normal  1: first MAD, initialize L0C
    uint16_t sL0cLast_; // 0: normal  1: last MAD of the accumulation
    // state
    uint16_t ssAl0PingPongFlag_; // incremented each Compute(); bit 0 selects ping/pong
    uint16_t ssBl0PingPongFlag_; // kept in lockstep with the A flag
    uint16_t kDirectionAlign_;   // forwarded to MmadParams.kDirectionAlign (set for f32/f32)
    // instance args
    // 0:format(M, K)
    // 1:format(K, M), need set transpose
    uint16_t ssAmatrixTranspose1_; // template params
    // 0:format(K, N), use load3dv2 carry
    // 1:format(N, K), use load2d carry
    uint16_t ssBmatrixTranspose1_;
    // 0: no bias
    // 1: fp16
    // 2: fp32
    // need judge inside
    uint16_t biasType_;
    // NOTE(review): the former private member declaration `GetK0Value()` was never
    // defined and shadowed the namespace-scope function template used below; removed.
    static constexpr uint16_t hwK0_ = GetK0Value<C_T, A_T, B_T>();
    static constexpr uint16_t typeSize_ = sizeof(A_T);

    TBuf<TPosition::A2> l0aBuf_;  // L0A buffer
    TBuf<TPosition::B2> l0bBuf_;  // L0B buffer
    TBuf<TPosition::C2> biasBuf_; // bias table buffer

    // One-time state/buffer initialization; must be called before Compute().
    inline __aicore__ void Init();
    // Carry A/B tiles from L1 to L0A/L0B, then issue one MAD into cMatrix.
    inline __aicore__ void Compute(const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
        const LocalTensor<C_T> &cMatrix, const LocalTensor<BIAS_T> &bias);

private:
    // Carry one A tile from L1 to L0A (load3dv2; transposed or plain layout).
    inline __aicore__ void LoadL12L0A(uint16_t usedK,
        const LocalTensor<A_T> &l1A, LocalTensor<A_T> &l0A);
    // Carry one B tile from L1 to L0B (load2d when (N, K), load3dv2 when (K, N)).
    inline __aicore__ void LoadL12L0B(const LocalTensor<B_T> &l1B, LocalTensor<B_T> &l0B);
    // Issue the MAD instruction for one (basicM x mmadK x basicN) step.
    inline __aicore__ void MmadMacro(const LocalTensor<A_T> &l0A, const LocalTensor<B_T> &l0B,
        const LocalTensor<C_T> &cMatrix, uint16_t mmadK, uint8_t unitFlag);
};

// Issue one MAD of shape (basicM x mmadK x basicN) from L0A/L0B into cMatrix (L0C).
// When bias is enabled, the first MAD of an accumulation sources the C matrix from
// the bias table instead of zero-initializing it.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, bool isBias, const auto &MM_CFG,
    uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulBasic<IMPL, C_T, A_T, B_T, BIAS_T, isBias, MM_CFG, UNIFLAG_EN, L0AB_USING_HSET>::MmadMacro(
    const LocalTensor<A_T> &l0A, const LocalTensor<B_T> &l0B, const LocalTensor<C_T> &cMatrix,
    uint16_t mmadK, uint8_t unitFlag)
{
    MmadParams mmadParams;
    mmadParams.m = ToMatmulConfig(MM_CFG).basicM;
    mmadParams.k = mmadK;
    mmadParams.n = ToMatmulConfig(MM_CFG).basicN;
    mmadParams.unitFlag = unitFlag;
    mmadParams.kDirectionAlign = kDirectionAlign_;
    if constexpr (isBias) {
        // First MAD reads C from the bias source; later MADs accumulate onto L0C.
        mmadParams.cmatrixSource = sL0cInit_;
        mmadParams.cmatrixInitVal = false;
    } else {
        // No bias: first MAD zero-initializes L0C, later MADs accumulate.
        mmadParams.cmatrixSource = false;
        mmadParams.cmatrixInitVal = sL0cInit_;
    }
    Mmad(cMatrix, l0A, l0B, mmadParams);
    // NOTE(review): for small MAD shapes (< 10 fractal tiles of 16x16) a PIPE_M
    // barrier is inserted — presumably to serialize back-to-back short MADs;
    // confirm against the hardware pipeline documentation.
    if constexpr ((ToMatmulConfig(MM_CFG).basicM / ALIGN_NUM) * (ToMatmulConfig(MM_CFG).basicN / ALIGN_NUM) < 10) {
        PipeBarrier<PIPE_M>();
    }
}

// Carry one A tile from L1 to L0A using load3dv2.
// usedK: runtime K extent of this step (used when basicK is not fixed at compile time).
// Transposed layout (K, M) routes K through the m direction of load3d (HW_M0 aligned);
// plain layout (M, K) routes K through the k direction (hwK0_ aligned).
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, bool isBias, const auto &MM_CFG,
    uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulBasic<IMPL, C_T, A_T, B_T, BIAS_T, isBias, MM_CFG, UNIFLAG_EN, L0AB_USING_HSET>::LoadL12L0A(
    uint16_t usedK, const LocalTensor<A_T> &l1A, LocalTensor<A_T> &l0A)
{
    if (ssAmatrixTranspose1_ > 0) {
        // format(K, M), K, M need to be 16 aligned for f32
        uint16_t usedKAlign; // k value optimization
        if constexpr (ToMatmulConfig(MM_CFG).basicK != 0) {
            constexpr uint16_t align = (ToMatmulConfig(MM_CFG).basicK + HW_M0 - 1) / HW_M0 * HW_M0;
            usedKAlign = align;
        } else {
            usedKAlign = CeilAlignNum(usedK, HW_M0);
        }
        // K_axis is m direction, and M_axis is k direction in load3d intrin
        LoadData3DParamsV2Pro loadData3DV2;
        loadData3DV2.channelSize = ToMatmulConfig(MM_CFG).basicM;
        loadData3DV2.extConfig = ((uint64_t)0 << 48) | ((uint64_t)0 << 32) |
                               ((uint64_t)usedKAlign << 16) | (uint64_t)ToMatmulConfig(MM_CFG).basicM;
        loadData3DV2.enTranspose = true;
#if __CCE_AICORE__ >= 220 && __CCE_AICORE__ != 310
        if constexpr (IsSameType<A_T, bfloat16_t>::value) {
            // bfloat16 uses the non-templated overload on >= v220 cores.
            LoadData(l0A[0], l1A[0], loadData3DV2);
        } else {
            LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
        }
#else
        LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
#endif
    } else {
        // format(M, K), K_axis is k direction, and M_axis is m direction in load3d intrin
        // k direction need to be 8 aligned for f32
        uint16_t usedKAlign;
        if constexpr (ToMatmulConfig(MM_CFG).basicK != 0) {
            constexpr uint16_t align = (ToMatmulConfig(MM_CFG).basicK + hwK0_ - 1) / hwK0_ * hwK0_;
            usedKAlign = align;
        } else {
            // Align to hwK0_ to match the compile-time branch above (previously HW_M0,
            // which under-aligned for dtypes with hwK0_ == 32 and over-aligned for f32).
            usedKAlign = CeilAlignNum(usedK, hwK0_);
        }

        LoadData3DParamsV2Pro loadData3DV2;
        loadData3DV2.channelSize = sAL1K_;
        loadData3DV2.extConfig = ((uint64_t)0 << 48) | ((uint64_t)0 << 32) |
                               ((uint64_t)ToMatmulConfig(MM_CFG).basicM << 16) | (uint64_t)usedKAlign;
#if __CCE_AICORE__ >= 220 && __CCE_AICORE__ != 310
        if constexpr (IsSameType<A_T, bfloat16_t>::value) {
            LoadData(l0A[0], l1A[0], loadData3DV2);
        } else {
            LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
        }
#else
        LoadData<A_T>(l0A[0], l1A[0], loadData3DV2);
#endif
    }
}

// Carry one B tile from L1 to L0B.
// (N, K) layout uses load2d with loop/repeat/stride chosen so that the larger of the
// K-fractal and N-fractal counts becomes the repeat axis; (K, N) layout uses load3dv2.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, bool isBias, const auto &MM_CFG,
    uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulBasic<IMPL, C_T, A_T, B_T, BIAS_T, isBias, MM_CFG, UNIFLAG_EN, L0AB_USING_HSET>::LoadL12L0B(
    const LocalTensor<B_T> &l1B, LocalTensor<B_T> &l0B)
{
    if (ssBmatrixTranspose1_ > 0) {
        // SET LOAD2D parameters , loop axis: K or M, or 1
        // k is hwK0_ aligned for f32
        constexpr uint16_t nFraC0 = ToMatmulConfig(MM_CFG).basicN / HW_N0; // N-direction fractal count
        uint16_t l0bLoop = 1;           // outer software-loop count
        uint16_t l0bSrcAddrStride = 0;  // L1 byte stride between loop iterations
        uint16_t l0bDstAddrStride = 0;  // L0B byte stride between loop iterations
        uint8_t l0bRepeat = 0;          // load2d repeat count per iteration
        uint16_t l0bSrcstride = 1;      // load2d source stride (fractal units)
        uint16_t l0bDststride = 0;      // load2d destination gap (fractal units)
        if constexpr (ToMatmulConfig(MM_CFG).singleCoreM != 0 && ToMatmulConfig(MM_CFG).singleCoreN != 0) {
            // Fully compile-time shape: strides fold to constants.
            constexpr uint16_t kC0 = (ToMatmulConfig(MM_CFG).basicK + hwK0_ - 1) / hwK0_; // K-direction fractal count
            constexpr uint16_t repeat = kC0 * nFraC0;
            l0bRepeat = repeat;
            if constexpr (nFraC0 * HW_N0 == ToMatmulConfig(MM_CFG).basicN) {
                l0bLoop = 1;            // loop=1
            } else if constexpr (nFraC0 >= kC0) { // LOOP is K  and repeat is n axis
                l0bLoop = kC0;
                constexpr uint16_t srcStride = ToMatmulConfig(MM_CFG).basicN * hwK0_ * typeSize_;
                constexpr uint16_t dstStride = nFraC0 * HW_N0 * hwK0_ * typeSize_;
                l0bSrcAddrStride = srcStride;
                l0bDstAddrStride = dstStride;
                l0bRepeat = nFraC0;

                l0bSrcstride = 1;
                l0bDststride = 0;
            } else { // LOOP is N  and repeat is K axis
                l0bLoop = nFraC0;
                constexpr uint16_t srcStride = HW_N0 * hwK0_ * typeSize_;
                constexpr uint16_t dstStride = HW_N0 * hwK0_ * typeSize_;
                l0bSrcAddrStride = srcStride;
                l0bDstAddrStride = dstStride;
                l0bRepeat = kC0;

                l0bSrcstride = nFraC0;
                l0bDststride = nFraC0 - 1;
            }
        } else {
            // Runtime K extent: same stride selection computed from sMad0K_.
            uint16_t sMad0KAlign = CeilAlignNum(sMad0K_, hwK0_);
            uint16_t kC0 = sMad0KAlign / hwK0_;
            l0bRepeat = kC0 * nFraC0;

            if constexpr (nFraC0 * HW_N0 == ToMatmulConfig(MM_CFG).basicN) {
                l0bLoop = 1;            // loop=1
            } else if (nFraC0 >= kC0) { // LOOP is K  and repeat is n axis
                l0bLoop = kC0;
                l0bSrcAddrStride = ToMatmulConfig(MM_CFG).basicN * hwK0_ * typeSize_;
                l0bDstAddrStride = nFraC0 * HW_N0 * hwK0_ * typeSize_;
                l0bRepeat = nFraC0;

                l0bSrcstride = 1;
                l0bDststride = 0;
            } else { // LOOP is N  and repeat is K axis
                l0bLoop = nFraC0;
                l0bSrcAddrStride = HW_N0 * hwK0_ * typeSize_;
                l0bDstAddrStride = HW_N0 * hwK0_ * typeSize_;
                l0bRepeat = kC0;

                l0bSrcstride = ToMatmulConfig(MM_CFG).basicN / HW_N0; // == nFraC0
                l0bDststride = nFraC0 - 1;
            }
        }
        // use load2d for L1_2_L0B
        LoadData2dParams loadDataParams;
        loadDataParams.repeatTimes = l0bRepeat;
        loadDataParams.srcStride = l0bSrcstride;
        loadDataParams.dstGap = l0bDststride;
        loadDataParams.ifTranspose = 0;
        uint64_t l1bOffset = 0;
        uint64_t l0bOffset = 0;
        // Address strides are in bytes; tensor indexing is in elements, hence / typeSize_.
        for (uint64_t i = 0; i < l0bLoop; i++) {
            LoadData(l0B[l0bOffset], l1B[l1bOffset], loadDataParams);
            l1bOffset += (l0bSrcAddrStride / typeSize_);
            l0bOffset += (l0bDstAddrStride / typeSize_);
        }
    } else {
        // use load3dv2 for L1_2_L0B
        // n_axis is K direction, need to be 16 aligned
        // channel size need to be 16 aligned
        // k_axis is M direction, need to be HW_M0 aligned
        uint16_t mAlign;
        if constexpr (ToMatmulConfig(MM_CFG).basicK != 0) {
            constexpr uint16_t align = (ToMatmulConfig(MM_CFG).basicK + HW_M0 - 1) / HW_M0 * HW_M0;
            mAlign = align;
        } else {
            mAlign = CeilAlignNum(sMad0K_, HW_M0);
        }
        // StepN need to be aligned
        LoadData3DParamsV2Pro loadData3DV2;
        loadData3DV2.channelSize = ToMatmulConfig(MM_CFG).basicN;
        loadData3DV2.extConfig = ((uint64_t)0 << 48) | ((uint64_t)0 << 32) |
                               ((uint64_t)mAlign << 16) | (uint64_t)ToMatmulConfig(MM_CFG).basicN;
        loadData3DV2.fMatrixCtrl = true;
#if __CCE_AICORE__ >= 220 && __CCE_AICORE__ != 310
        if constexpr (IsSameType<B_T, bfloat16_t>::value) {
            // bfloat16 uses the non-templated overload on >= v220 cores.
            LoadData(l0B[0], l1B[0], loadData3DV2);
        } else {
            LoadData<B_T>(l0B[0], l1B[0], loadData3DV2);
        }
#else
        LoadData<B_T>(l0B[0], l1B[0], loadData3DV2);
#endif
    }
}

// One-time initialization: resets all per-instance state, allocates the L0A/L0B/bias
// hardware buffers, and (in CPU debug builds) rebases the ping/pong offsets onto
// heap-allocated stand-ins for the on-chip memories. Must run before Compute().
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, bool isBias, const auto &MM_CFG,
    uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulBasic<IMPL, C_T, A_T, B_T, BIAS_T, isBias, MM_CFG, UNIFLAG_EN, L0AB_USING_HSET>::Init()
{
    if constexpr (unlikely(UNIFLAG_EN)) {
        SetMMLayoutTransform(0);
    }
#ifdef ASCENDC_CPU_DEBUG
    // allocate 64K L0A space for cpu debug
    uint64_t pA = (uint64_t)((__ca__ A_T *)malloc(L0AUF_SIZE)); // use api
    // allocate 64K L0B space for cpu debug
    uint64_t pB = (uint64_t)((__cb__ B_T *)malloc(L0BUF_SIZE));
    uint64_t pBias = (uint64_t)((C_T *)malloc(BIAS_BUF_SIZE));
    // Turn the ping/pong byte offsets into absolute host addresses.
    L0A_PING += pA;
    L0A_PONG += pA;
    L0B_PING += pB;
    L0B_PONG += pB;
    BIAS_PING += pBias;
    BIAS_PONG += pBias;
#endif
    ssAl0PingPongFlag_ = 0;
    ssBl0PingPongFlag_ = 0;
    ssAmatrixTranspose1_ = 0;
    ssBmatrixTranspose1_ = 0;
    biasType_ = 0; // caller must set from BIAS_T before a biased Compute()
    kDirectionAlign_ = 0;
    sL0cInit_ = 1; // first MAD initializes L0C
    sL0cLast_ = 0;
    GetTPipePtr()->InitBuffer(l0aBuf_, L0AUF_SIZE);
    GetTPipePtr()->InitBuffer(l0bBuf_, L0BUF_SIZE);
    GetTPipePtr()->InitBuffer(biasBuf_, BIAS_BUF_SIZE);
}

// One macro step: program the load3d f-matrix registers, optionally stage the bias
// table into C2, carry the A/B tiles from L1 into the ping/pong half of L0A/L0B,
// then issue one MAD into cMatrix. Ping/pong selection and the MTE1<->M event flags
// pipeline consecutive calls.
template <typename IMPL, typename C_T, typename A_T, typename B_T, typename BIAS_T, bool isBias, const auto &MM_CFG,
    uint16_t UNIFLAG_EN, uint16_t L0AB_USING_HSET>
inline __aicore__ void MacroMatmulBasic<IMPL, C_T, A_T, B_T, BIAS_T, isBias, MM_CFG, UNIFLAG_EN, L0AB_USING_HSET>::Compute(
    const LocalTensor<A_T> &l1AMatrix, const LocalTensor<B_T> &l1BMatrix,
    const LocalTensor<C_T> &cMatrix, const LocalTensor<BIAS_T> &bias)
{
    if (ssAmatrixTranspose1_ > 0) {
        if constexpr (IsSameType<C_T, float>::value && IsSameType<A_T, float>::value) {
            kDirectionAlign_ = 1; // f32/f32 transposed A needs K-direction alignment in MAD
        }
        Load3DSetFMatrixCal(1, sAL1K_, padList);
    } else {
        // fmatrix w should be 16 aligned
        Load3DSetFMatrixCal(1, ToMatmulConfig(MM_CFG).basicM, padList);
    }

    if (ssBmatrixTranspose1_ < 1) {
        Load3DSetFMatrixBCal(1, sBL1K_, padList);
    }
    if constexpr (isBias) {
        // Stage the bias row into the C2 buffer only on the first MAD of an accumulation.
        if (sL0cInit_) {
            LocalTensor<C_T> biasC2;
            biasC2 = biasBuf_.Get<C_T>();
            if (biasType_ == 1) { // fp16
                // ceil(basicN / 32) bursts. The original expression "basicN + 31 / 32"
                // lacked parentheses and evaluated to basicN.
                constexpr uint16_t lenBurst = (ToMatmulConfig(MM_CFG).basicN + 31) / 32;
                DataCopy(biasC2, bias, {1, lenBurst, 0, 0});
            } else { // fp32
                // ceil(basicN / 16) bursts; same parenthesization fix as above.
                constexpr uint16_t lenBurst = (ToMatmulConfig(MM_CFG).basicN + 15) / 16;
                DataCopy(biasC2, bias, {1, lenBurst, 0, 0});
            }
        }
    }
    LocalTensor<A_T> l0a;
    LocalTensor<B_T> l0b;
    l0a = l0aBuf_.Get<A_T>();
    l0b = l0bBuf_.Get<B_T>();
    // Odd iterations use the pong (upper) half of L0A/L0B.
    if ((ssAl0PingPongFlag_ & 0x1) != 0) {
        l0a = l0a[L0AUF_SIZE / 2 / sizeof(A_T)];
        l0b = l0b[L0BUF_SIZE / 2 / sizeof(B_T)];
    }

    // Wait until the MAD unit has released this ping/pong half before overwriting it.
    WaitFlag<HardEvent::M_MTE1>(ssAl0PingPongFlag_ & 0x1);
    // load L0A
    LoadL12L0A(sMad0K_, l1AMatrix, l0a);
    // load L0B
    LoadL12L0B(l1BMatrix, l0b);
    if constexpr (!L0AB_USING_HSET) {
        // Hand off MTE1 -> M: the MAD must not start before both loads complete.
        SetFlag<HardEvent::MTE1_M>(ssAl0PingPongFlag_ & 0x1);
        WaitFlag<HardEvent::MTE1_M>(ssAl0PingPongFlag_ & 0x1);
    }
    // MAD
    uint8_t unitFlag = 0;
    if constexpr (UNIFLAG_EN) {
        unitFlag = (sL0cLast_) ? 3 : 2; // 3 on the final MAD, 2 otherwise
    }
    MmadMacro(l0a, l0b, cMatrix, sMad0K_, unitFlag);
    if constexpr (!L0AB_USING_HSET) {
        // Release this half back to MTE1 for the next-but-one iteration.
        SetFlag<HardEvent::M_MTE1>(ssAl0PingPongFlag_ & 0x1);
    }
    // update pingpong flag
    ssAl0PingPongFlag_ += 1;
    ssBl0PingPongFlag_ += 1;
}
} // namespace matmul
#endif